"""
Helpers for evaluating models.
"""

from .minibatchprox import MinibatchProx
from .variables import weight_decay
import numpy as np
# pylint: disable=R0913,R0914
def evaluate(sess,
             model,
             dataset,
             num_classes=5,
             num_shots=5,
             eval_inner_batch_size=5,
             eval_inner_iters=50,
             replacement=False,
             num_samples=10000,
             transductive=False,
             weight_decay_rate=1,
             lam_reg=0.1,
             MinibatchProx_m=MinibatchProx,
             dataset_name='tieredimagenet'):
    """
    Evaluate a model on a dataset by averaging over ``num_samples`` tasks.

    A meta-learner is built once from ``MinibatchProx_m``. Its ``evaluate``
    method is then called ``num_samples`` times, and each returned
    correct-count is accumulated. The function prints the mean and standard
    deviation of the per-task accuracies, plus a 95% confidence half-width.
    It returns the overall accuracy.

    Args:
      sess: session passed as the first argument to ``MinibatchProx_m``.
      model: object exposing ``input_ph``, ``label_ph``, ``minimize_op`` and
        ``predictions``. These are passed positionally to the meta-learner's
        ``evaluate``, and ``model`` itself is also passed as ``model=``.
      dataset: forwarded as the first argument to the meta-learner's
        ``evaluate``.
      num_classes: forwarded as ``num_classes=``. Also the denominator that
        turns one task's correct-count into an accuracy. Must be positive.
      num_shots: forwarded as ``num_shots=``.
      eval_inner_batch_size: forwarded as ``inner_batch_size=``.
      eval_inner_iters: forwarded as ``inner_iters=``.
      replacement: forwarded as ``replacement=``.
      num_samples: number of ``evaluate`` calls (tasks) to average over.
        Must be positive.
      transductive: forwarded to the ``MinibatchProx_m`` constructor.
      weight_decay_rate: passed to ``weight_decay``, whose result becomes
        the meta-learner's ``pre_step_op``.
      lam_reg: forwarded as ``lam_reg=``.
      MinibatchProx_m: meta-learner class to instantiate (defaults to
        ``MinibatchProx``).
      dataset_name: forwarded as ``dataset_name=``.

    Returns:
      Sum of all per-task correct-counts divided by
      ``num_samples * num_classes``.

    Raises:
      ValueError: if ``num_samples`` or ``num_classes`` is not positive.
    """
    # Fail fast. Without this guard, zero tasks produce NaN statistics (with
    # a RuntimeWarning) followed by a ZeroDivisionError in the return line.
    if num_samples <= 0:
        raise ValueError('num_samples must be positive, got %r' % (num_samples,))
    if num_classes <= 0:
        raise ValueError('num_classes must be positive, got %r' % (num_classes,))

    metaminibatchprox = MinibatchProx_m(sess,
                                        transductive=transductive,
                                        pre_step_op=weight_decay(weight_decay_rate))
    total_correct = 0
    metaval_accuracies = []
    for _ in range(num_samples):
        correct_no = metaminibatchprox.evaluate(
            dataset, model.input_ph, model.label_ph,
            model.minimize_op, model.predictions,
            num_classes=num_classes, num_shots=num_shots,
            inner_batch_size=eval_inner_batch_size,
            inner_iters=eval_inner_iters, replacement=replacement,
            lam_reg=lam_reg, model=model, dataset_name=dataset_name)
        total_correct += correct_no
        # Per-task accuracy. This presumably assumes ``correct_no`` counts
        # correct predictions out of ``num_classes`` query examples (one per
        # class). NOTE(review): confirm against the learner's ``evaluate``.
        metaval_accuracies.append(correct_no / num_classes)

    metaval_accuracies = np.array(metaval_accuracies)
    means = np.mean(metaval_accuracies, 0)
    stds = np.std(metaval_accuracies, 0)
    # Normal-approximation 95% confidence half-width of the mean accuracy.
    ci95 = 1.96 * stds / np.sqrt(num_samples)

    print('Mean validation accuracy/loss, stddev, and confidence intervals')
    print((means, stds, ci95))
    return total_correct / (num_samples * num_classes)
